import copy
import pickle
import random
import torch
import numpy as np
import random
import torchvision.transforms as transforms
from numpy import uint8

import matplotlib.pyplot as plt
from PIL import Image

from skimage.transform import resize

from Causal_MNIST_Images.DigitImageGeneration.morphomnist import io
from ModularUtils.DigitImageGeneration.ColoringMnist import color_grayscale_arr
from ModularUtils.DigitImageGeneration.mnist_image_generation import label_to_digit_image
from ModularUtils.DigitImageGeneration.morphomnist import perturb, morpho


def test_result_data(Exp, intv_no):
    """Display every stored digit image of one intervention, one figure each.

    Reads ``digitimages.gz`` and ``digitlabels.gz`` from
    ``Exp.file_roots[intv_no]``, converts each image with
    ``ToPILImage`` -> ``ToTensor``, stacks the results on ``Exp.DEVICE``,
    and then for every image prints its label entry and shows it in a new
    matplotlib figure.
    """
    print("TESTING")
    root = Exp.file_roots[intv_no]
    loaded_images = io.load_idx(f"{root}digitimages.gz")
    labels_data = io.load_idx(f"{root}digitlabels.gz")

    to_tensor = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
    ])

    # Stacking the converted images gives one (N, C, H, W) batch tensor.
    batch = torch.stack([to_tensor(raw).to(Exp.DEVICE) for raw in loaded_images])

    for position, tensor_img in enumerate(batch):
        print(labels_data[position])
        # Back to channels-last so matplotlib can render it.
        hwc_array = tensor_img.permute(1, 2, 0).detach().cpu().numpy()
        plt.subplots()
        plt.imshow(hwc_array)
        plt.show()

def check_any_digit(images_data, labels_data, digit=7,
                    out_dir="SAVED_EXPERIMENTS/mnist_addition_graph/preprocessed_dataset/result_images"):
    """Save every image of one digit class as an RGB JPEG for visual inspection.

    Args:
        images_data: indexable collection of images; each entry is passed to
            ``PIL.Image.fromarray`` and converted to RGB.
        labels_data: per-image labels where ``labels_data[i][0]`` is the digit
            class of ``images_data[i]``.
        digit: digit class to export (default 7).
        out_dir: directory the files are written to as
            ``<digit>_<k>.jpeg``, where ``k`` is the image's rank among all
            images of that digit in dataset order. The directory is not
            created here.
    """
    # Dataset positions of all images whose label is `digit`, in order.
    matching = [pos for pos, lbl in enumerate(labels_data) if lbl[0] == digit]

    for rank, imgid in enumerate(matching):
        im = Image.fromarray(images_data[imgid]).convert('RGB')
        im.save(f"{out_dir}/{digit}_{rank}.jpeg")
        print(f"{rank} image of {digit} is saved")
    print("done")

def plot_trained_digits(rows, columns, images, title):
    """Draw the first ``rows * columns`` images in a grid and show it.

    Every subplot receives the same ``title``. Images are taken from
    ``images`` in order and placed row by row in a 13x8-inch figure.
    """
    fig = plt.figure(figsize=(13, 8))
    axes = []

    for slot in range(columns * rows):
        picture = images[slot]
        axis = fig.add_subplot(rows, columns, slot + 1)
        axes.append(axis)
        axis.set_title(title)
        plt.imshow(picture)

    plt.show()


def plot_dataset_digits(image):
    """Show a single ``(img, label)`` pair in a new 13x8-inch figure.

    Args:
        image: a 2-tuple ``(img, label)``. ``img`` is passed to
            ``plt.imshow``; ``label`` is rendered in the title as
            ``"Label: <label>"``.
    """
    fig = plt.figure(figsize=(13, 8))
    img, label = image
    # One subplot filling the whole 1x1 grid.
    axis = fig.add_subplot(1, 1, 1)
    axis.set_title("Label: " + str(label))
    plt.imshow(img)
    plt.show()


def _load_feature_columns(Exp, intv_no, labels):
    """Load one pickled feature per label and stack them as columns on Exp.DEVICE.

    For each name in ``labels``, reads
    ``Exp.file_roots[intv_no] + name + "feature.pkl"``, converts it to a
    ``torch.FloatTensor`` and reshapes it to one column. The columns are
    concatenated along dim 1 in the order given.
    """
    columns = []
    for label in labels:
        file_name = Exp.file_roots[intv_no] + label + "feature" + ".pkl"
        # NOTE(review): pickle.load can execute arbitrary code; these files
        # must come from a trusted run of this project.
        with open(file_name, 'rb') as fp:
            label_data = torch.FloatTensor(pickle.load(fp))
        columns.append(label_data.view(len(label_data), 1))
    return torch.cat(columns, 1).to(Exp.DEVICE)


def produce_result_image(Exp, intv_no, num_images, SAVE_DATASET):
    """Render digit-1 images from pickled per-sample features and optionally save them.

    Loads the ``Ydigit1``/``Ydigit2``/``Ycolor``/``Ythick`` features (the
    image attributes) and the ``X1``/``X2``/``W`` features (printed for
    reference) from ``Exp.file_roots[intv_no]``. For each of the first
    ``num_images`` samples, renders an MNIST training image of ``Ydigit1``
    with color ``Ycolor`` and thickness ``Ythick`` via
    ``label_to_digit_image``. A few samples per 1000 are shown as previews.

    Args:
        Exp: experiment config; reads ``file_roots``, ``DEVICE``,
            ``IMAGE_SIZE`` and ``image_labels``.
        intv_no: index into ``Exp.file_roots`` for inputs and outputs.
        num_images: number of samples to render.
        SAVE_DATASET: container of image-label names; when it contains
            ``Exp.image_labels[0]``, the images and ``(digit, color, thick)``
            labels are written as ``Ydigit1images.gz`` / ``Ydigit1labels.gz``.

    Raises:
        ValueError: if ``num_images`` exceeds the number of samples in the
            loaded feature files.
    """
    result_dataset = _load_feature_columns(
        Exp, intv_no, ["Ydigit1", "Ydigit2", "Ycolor", "Ythick"])
    print(result_dataset.shape)

    lbl_dataset = _load_feature_columns(Exp, intv_no, ["X1", "X2", "W"])
    print(lbl_dataset.shape)

    available = min(result_dataset.shape[0], lbl_dataset.shape[0])
    if num_images > available:
        raise ValueError(
            f"num_images={num_images} exceeds the {available} samples in the feature files")

    images_data = io.load_idx(
        "/path_to_project/CausalMNISTAddition/input_dir/train-images-idx3-ubyte.gz")
    labels_data = io.load_idx(
        "/path_to_project/CausalMNISTAddition/input_dir/train-labels-idx1-ubyte.gz")

    perturbed_images_digit1 = np.zeros((num_images, Exp.IMAGE_SIZE, Exp.IMAGE_SIZE, 3))
    perturbing_labels_digit1 = np.zeros((num_images, 3))

    # Built once and reused for every preview. It turns an HxWxC uint8
    # array into a CxHxW tensor (undone by the permute below).
    to_tensor = transforms.Compose([transforms.ToPILImage(),
                                    transforms.ToTensor(),
                                    ])

    for idx in range(num_images):
        sample = result_dataset[idx, :].detach().cpu().numpy().astype(int)
        print(f"sample no {idx}, label:{lbl_dataset[idx, :]},  image:{sample}")

        # Columns are [Ydigit1, Ydigit2, Ycolor, Ythick]; digit 1 uses 0, 2, 3.
        # NOTE: no image is generated here for the second digit (Ydigit2).
        digit1, color1, thick1 = sample[0], sample[2], sample[3]
        print(digit1, color1, thick1)
        perturbed_images_digit1[idx] = label_to_digit_image(
            Exp, images_data, labels_data, digit=digit1, color_id=color1, thickness=thick1)
        perturbing_labels_digit1[idx] = [digit1, color1, thick1]

        if idx % 1000 <= 5:
            # Visual spot-check of the first six samples of every thousand.
            imgg = perturbed_images_digit1[idx]
            plot_dataset_digits((imgg, perturbing_labels_digit1[idx]))
            digit_images = torch.squeeze(to_tensor(imgg.astype(uint8)), dim=0).to(Exp.DEVICE)
            plt.imshow(digit_images.permute(1, 2, 0).detach().cpu().numpy())
            plt.show()

    if Exp.image_labels[0] in SAVE_DATASET:  # Ydigit1 image
        images_path = Exp.file_roots[intv_no] + "Ydigit1images.gz"
        labels_path = Exp.file_roots[intv_no] + "Ydigit1labels.gz"
        io.save_idx(perturbed_images_digit1, images_path)
        io.save_idx(perturbing_labels_digit1, labels_path)
        print("image saved at", images_path)
        print("labels saved at", labels_path)





def produce_uniform_images(Exp, intv_no, num_images, SAVE_DATASET, start_index=39000):
    """(Re)generate dataset entries from uniformly random (digit, color, thickness) draws.

    Loads an existing dataset from ``Exp.file_roots[-1]``
    (``digitimages.gz`` / ``digitlabels.gz``) and overwrites entries
    ``start_index .. num_images - 1``. Earlier entries are kept as loaded,
    so an interrupted run can be resumed. Each new entry uses a digit in
    ``[0, 10)``, a color id in ``[0, 3)`` and a thickness id in ``[0, 2)``,
    drawn with ``torch.randint``, and is rendered from MNIST training images
    by ``label_to_digit_image``.

    Args:
        Exp: experiment config; reads ``file_roots`` and ``DEVICE`` (and is
            passed through to ``label_to_digit_image``).
        intv_no: index into ``Exp.file_roots`` of the output directory.
        num_images: total number of dataset entries.
        SAVE_DATASET: if truthy, the dataset is checkpointed once every 1000
            entries and written again at the end.
        start_index: first entry to regenerate (default 39000).

    Raises:
        ValueError: if there is work to do and the loaded dataset has fewer
            than ``num_images`` entries.
    """
    images_data = io.load_idx(
        "/path_to_project/CausalMNISTAddition/input_dir/train-images-idx3-ubyte.gz")
    labels_data = io.load_idx(
        "/path_to_project/CausalMNISTAddition/input_dir/train-labels-idx1-ubyte.gz")

    # deepcopy gives independent arrays that are updated in place below
    # (presumably because load_idx may return read-only buffers; confirm).
    perturbed_images_digit = copy.deepcopy(io.load_idx(Exp.file_roots[-1] + "digitimages.gz"))
    perturbing_labels_digit = copy.deepcopy(io.load_idx(Exp.file_roots[-1] + "digitlabels.gz"))

    existing = min(len(perturbed_images_digit), len(perturbing_labels_digit))
    if start_index < num_images and existing < num_images:
        raise ValueError(
            f"loaded dataset has {existing} entries, fewer than num_images={num_images}")

    # Draws are made for all num_images rows, although only rows
    # >= start_index are used.
    digits = torch.randint(0, 10, (num_images, 1))
    colors = torch.randint(0, 3, (num_images, 1))
    thickness = torch.randint(0, 2, (num_images, 1))
    result_dataset = torch.cat([digits, colors, thickness], 1)

    out_root = Exp.file_roots[intv_no]

    def _save_dataset():
        """Write the current images and labels to ``out_root``."""
        io.save_idx(perturbed_images_digit, out_root + "digitimages.gz")
        io.save_idx(perturbing_labels_digit, out_root + "digitlabels.gz")
        print("image saved at", out_root + "digitimages.gz")
        print("labels saved at", out_root + "digitlabels.gz")

    # Built once and reused for every preview. It turns an HxWxC uint8
    # array into a CxHxW tensor (undone by the permute below).
    to_tensor = transforms.Compose([transforms.ToPILImage(),
                                    transforms.ToTensor(),
                                    ])

    for idx in range(start_index, num_images):
        sample = result_dataset[idx, :].detach().cpu().numpy().astype(int)
        print(f"sample no {idx},  image:{sample}")

        digit1, color1, thick1 = sample[0], sample[1], sample[2]
        perturbed_images_digit[idx] = label_to_digit_image(
            Exp, images_data, labels_data, digit=digit1, color_id=color1, thickness=thick1)
        perturbing_labels_digit[idx] = [digit1, color1, thick1]

        if idx % 1000 <= 2:
            # Visual spot-check of the first three entries of every thousand.
            imgg = perturbed_images_digit[idx]
            plot_dataset_digits((imgg, perturbing_labels_digit[idx]))
            digit_images = torch.squeeze(to_tensor(imgg.astype(uint8)), dim=0).to(Exp.DEVICE)
            plt.imshow(digit_images.permute(1, 2, 0).detach().cpu().numpy())
            plt.show()

        # Checkpoint once per 1000 entries. Saving on every preview iteration
        # would rewrite the whole dataset several times in a row.
        if SAVE_DATASET and idx % 1000 == 0:
            _save_dataset()

    if SAVE_DATASET:
        _save_dataset()




